{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from sklearn.cross_validation import train_test_split\n",
    "from sklearn.cross_validation import cross_val_score\n",
    "from sklearn.preprocessing import scale\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import classification_report\n",
    "from PIL import Image\n",
    "\n",
    "X = []\n",
    "y = []\n",
    "\n",
    "for dirpath, _, filenames in os.walk('att-faces/orl_faces'):\n",
    "    for filename in filenames:\n",
    "        if filename[-3:] == 'pgm':\n",
    "            img = Image.open(os.path.join(dirpath, filename)).convert('L')\n",
    "            arr = np.array(img).reshape(10304).astype('float32') / 255.\n",
    "            X.append(arr)\n",
    "            y.append(dirpath)\n",
    "\n",
    "X = scale(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=11)\n",
    "pca = PCA(n_components=150)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(300, 10304)\n",
      "(300, 150)\n",
      "Cross validation accuracy: 0.807660834984\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "att-faces/orl_faces/s1       0.50      1.00      0.67         1\n",
      "att-faces/orl_faces/s10       1.00      1.00      1.00         3\n",
      "att-faces/orl_faces/s11       1.00      0.67      0.80         3\n",
      "att-faces/orl_faces/s12       1.00      1.00      1.00         5\n",
      "att-faces/orl_faces/s13       0.00      0.00      0.00         0\n",
      "att-faces/orl_faces/s14       1.00      1.00      1.00         4\n",
      "att-faces/orl_faces/s16       1.00      1.00      1.00         2\n",
      "att-faces/orl_faces/s17       0.67      1.00      0.80         2\n",
      "att-faces/orl_faces/s18       1.00      1.00      1.00         2\n",
      "att-faces/orl_faces/s19       0.83      1.00      0.91         5\n",
      "att-faces/orl_faces/s2       0.33      1.00      0.50         1\n",
      "att-faces/orl_faces/s20       1.00      1.00      1.00         2\n",
      "att-faces/orl_faces/s21       1.00      1.00      1.00         2\n",
      "att-faces/orl_faces/s22       1.00      1.00      1.00         1\n",
      "att-faces/orl_faces/s23       0.67      1.00      0.80         2\n",
      "att-faces/orl_faces/s24       1.00      1.00      1.00         3\n",
      "att-faces/orl_faces/s25       1.00      1.00      1.00         2\n",
      "att-faces/orl_faces/s26       1.00      1.00      1.00         3\n",
      "att-faces/orl_faces/s27       1.00      1.00      1.00         1\n",
      "att-faces/orl_faces/s28       1.00      0.50      0.67         4\n",
      "att-faces/orl_faces/s29       1.00      1.00      1.00         5\n",
      "att-faces/orl_faces/s3       1.00      1.00      1.00         3\n",
      "att-faces/orl_faces/s30       1.00      0.67      0.80         3\n",
      "att-faces/orl_faces/s31       0.75      1.00      0.86         3\n",
      "att-faces/orl_faces/s32       1.00      1.00      1.00         3\n",
      "att-faces/orl_faces/s34       1.00      0.83      0.91         6\n",
      "att-faces/orl_faces/s35       0.50      0.33      0.40         3\n",
      "att-faces/orl_faces/s36       1.00      1.00      1.00         3\n",
      "att-faces/orl_faces/s37       1.00      0.75      0.86         4\n",
      "att-faces/orl_faces/s38       1.00      1.00      1.00         3\n",
      "att-faces/orl_faces/s39       1.00      1.00      1.00         2\n",
      "att-faces/orl_faces/s4       1.00      0.75      0.86         4\n",
      "att-faces/orl_faces/s40       0.00      0.00      0.00         0\n",
      "att-faces/orl_faces/s5       1.00      0.67      0.80         3\n",
      "att-faces/orl_faces/s6       1.00      1.00      1.00         1\n",
      "att-faces/orl_faces/s7       1.00      1.00      1.00         3\n",
      "att-faces/orl_faces/s8       1.00      1.00      1.00         2\n",
      "att-faces/orl_faces/s9       1.00      1.00      1.00         1\n",
      "\n",
      "avg / total       0.94      0.90      0.91       100\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/gavin/classpass-activity-tagger/venv/local/lib/python2.7/site-packages/sklearn/metrics/classification.py:1076: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples.\n",
      "  'recall', 'true', average, warn_for)\n"
     ]
    }
   ],
   "source": [
    "X_train_reduced = pca.fit_transform(X_train)\n",
    "X_test_reduced = pca.transform(X_test)\n",
    "print(X_train.shape)\n",
    "print(X_train_reduced.shape)\n",
    "classifier = LogisticRegression()\n",
    "accuracies = cross_val_score(classifier, X_train_reduced, y_train)\n",
    "print('Cross validation accuracy: %s' % np.mean(accuracies))\n",
    "classifier.fit(X_train_reduced, y_train)\n",
    "predictions = classifier.predict(X_test_reduced)\n",
    "print(classification_report(y_test, predictions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
